import argparse
import os

import numpy as np
import pandas as pd
import psycopg2
import torch
from psycopg2 import sql

# Command-line options as (flag, add_argument keyword options), in the order
# they are registered on the parser.
_CLI_OPTIONS = (
    # PostgreSQL connection settings.
    ('--dbname', dict(required=True, help="Database name.")),
    ('--user', dict(default='postgres', help="Database user.")),
    ('--password', dict(required=True, help="Database password.")),
    ('--host', dict(default='localhost', help="Database host.")),
    ('--port', dict(default='5432', help="Database port.")),
    # Prediction target and output location (--target_label is optional).
    ('--target_table', dict(required=True, help="Target table name.")),
    ('--target_column', dict(required=True, help="Target column name.")),
    ('--target_label', dict(required=False, help="Target label value.")),
    ('--output_dir', dict(required=True, help="Path to save output tensors.")),
)

parser = argparse.ArgumentParser(description="Read dataset from PostgreSQL database.")
for flag, options in _CLI_OPTIONS:
    parser.add_argument(flag, **options)

args = parser.parse_args()

# ------------------------------------------------------------
#  Connect to PostgreSQL
# ------------------------------------------------------------
# Open one database session from the connection-related CLI options; it is
# used below to list the schema's tables and to read each of them.
conn = psycopg2.connect(
    **{key: getattr(args, key) for key in ("dbname", "user", "password", "host", "port")}
)
cur = conn.cursor()

# ------------------------------------------------------------
# Get all table names in public schema
# ------------------------------------------------------------
# ORDER BY makes the table order deterministic. That order later decides the
# per-table ids and the order in which vertex ids are assigned. Without it,
# PostgreSQL may return rows in any order and the saved tensors would not be
# reproducible across runs.
cur.execute("""
    SELECT table_name
    FROM information_schema.tables
    WHERE table_schema = 'public'
    ORDER BY table_name;
""")
tables = [name for (name,) in cur.fetchall()]

print(f"Found tables: {tables}")

# Fail fast with a clear message rather than a bare KeyError after every
# table has already been downloaded.
if args.target_table not in tables:
    raise SystemExit(
        f"Target table '{args.target_table}' not found in schema 'public'; "
        f"available tables: {tables}"
    )

# ------------------------------------------------------------
#  Load all tables into pandas DataFrames
# ------------------------------------------------------------
# Read every listed table into a DataFrame, keyed by table name.
dfs = {}
try:
    for t in tables:
        # The tables were listed from the 'public' schema, so read them
        # schema-qualified. An unqualified name would resolve through
        # search_path. psycopg2.sql also quotes names correctly, including
        # names that contain double quotes.
        query = sql.SQL("SELECT * FROM {};").format(sql.Identifier("public", t)).as_string(conn)
        df = pd.read_sql(query, conn)
        dfs[t] = df
        print(f"Loaded table '{t}' with shape {df.shape}")
finally:
    # Nothing after this point reads from the database; release it now
    # instead of holding it open through the tensor construction.
    cur.close()
    conn.close()

# ------------------------------------------------------------
#  Build v_id mapping for all attribute-value pairs
# ------------------------------------------------------------
# v_id maps column name -> {stringified value -> vertex id}. Every column of
# every table gets an entry, except the target column of the target table.
# Keys are column names only, so same-named columns in different tables share
# one value space.
v_id = {}
for table_name, frame in dfs.items():
    for column in frame.columns:
        is_target = table_name == args.target_table and column == args.target_column
        if not is_target:
            v_id.setdefault(column, {})

# Assign contiguous ids 0, 1, 2, ... in order of first appearance, walking
# the tables and columns in the same order as above.
next_id = 0
for table_name, frame in dfs.items():
    for column in (c for c in frame.columns if c in v_id):
        ids_for_column = v_id[column]
        for value in frame[column].astype(str).unique():
            if value in ids_for_column:
                continue
            ids_for_column[value] = next_id
            next_id += 1
i = next_id

print(f"Total unique (attribute, value) pairs: {i}")

# ------------------------------------------------------------
# Build hypergraph data
# ------------------------------------------------------------
# One hyperedge per row of every table. The edge is the list of vertex ids of
# the row's (column, value) pairs.
#   table_id[k]    -- position of edge k's table in dfs
#   train_test_idx -- edge positions that belong to the target table
Hypergraph_data = []
table_id = []
train_test_idx = []

for tid, (table, df) in enumerate(dfs.items()):
    is_target_table = table == args.target_table
    # Exclude the target column by name. Checking only "col in v_id" would
    # keep it (label leakage) whenever another table has a same-named column.
    feature_cols = [
        col for col in df.columns
        if col in v_id and not (is_target_table and col == args.target_column)
    ]
    lookups = [v_id[col] for col in feature_cols]
    # Stringify exactly as v_id was built (per-column astype(str)). str() on
    # values from iterrows() can differ: iterrows upcasts int columns to float
    # in all-numeric tables ('5' vs '5.0'), and str(Timestamp) differs from
    # astype(str). Either mismatch raises KeyError.
    for row_values in df[feature_cols].astype(str).to_numpy():
        if is_target_table:
            train_test_idx.append(len(Hypergraph_data))
        Hypergraph_data.append([lookup[val] for lookup, val in zip(lookups, row_values)])
        table_id.append(tid)

# ------------------------------------------------------------
# Target table for training labels
# ------------------------------------------------------------
# Per target-table row:
#   train_test_data -- the row's feature vertex ids (same edges as in
#                      Hypergraph_data)
#   labels          -- a float label derived from the target column
train_test_data = []
labels = []

target_df = dfs[args.target_table]

if args.target_column not in target_df.columns:
    raise SystemExit(
        f"Target column '{args.target_column}' not found in table "
        f"'{args.target_table}'; available columns: {list(target_df.columns)}"
    )

# Identify the label column by name, not by absence from v_id. A same-named
# column in another table would otherwise put it in v_id, turning the label
# into a feature and leaving the label None.
feature_cols = [c for c in target_df.columns if c in v_id and c != args.target_column]
lookups = [v_id[c] for c in feature_cols]
# Use the same astype(str) stringification that v_id was built from.
for row_values in target_df[feature_cols].astype(str).to_numpy():
    train_test_data.append([lookup[val] for lookup, val in zip(lookups, row_values)])

# Raw label strings, with double quotes stripped.
raw_labels = [val.replace('"', '') for val in target_df[args.target_column].astype(str)]

unique_labels = np.unique(raw_labels)
if args.target_label is not None:
    # Binary task: 1.0 for the requested label value, 0.0 for everything else.
    # Previously this mapping was computed and then immediately overwritten,
    # so --target_label had no effect.
    label_mapping = {label: 1.0 if label == args.target_label else 0.0 for label in unique_labels}
    if args.target_label not in label_mapping:
        print(f"Warning: target label '{args.target_label}' does not occur in "
              f"column '{args.target_column}'; all labels are 0.0")
else:
    # Multi-class: index of the label in sorted (string) order.
    label_mapping = {label: float(idx) for idx, label in enumerate(unique_labels)}
labels = [label_mapping[l] for l in raw_labels]

# ------------------------------------------------------------
#  Create tensors
# ------------------------------------------------------------
# torch.save does not create directories; make sure the output location exists.
os.makedirs(args.output_dir, exist_ok=True)

# One-hot (num_edges, max_tid) matrix: row k has a 1 in column table_id[k].
# default=-1 avoids max() failing when no rows were loaded.
max_tid = max(table_id, default=-1) + 1
table_mapping_tensor = torch.zeros((len(table_id), max_tid), dtype=torch.float32)
table_mapping_tensor[torch.arange(len(table_id)), torch.tensor(table_id, dtype=torch.long)] = 1
torch.save(table_mapping_tensor, f"{args.output_dir}/table_mapping_tensor.pt")

# Vertex count. v_id assigns ids contiguously from 0, so this bounds every id
# in any edge. It also avoids max() over edges, which fails when any edge is
# empty (e.g. a target table with only the target column).
max_index = sum(len(ids) for ids in v_id.values())

# Sparse incidence matrix (num_edges, num_vertices): 1.0 where an edge
# contains a vertex.
edge_rows, edge_cols = [], []
for edge_idx, edge in enumerate(Hypergraph_data):
    edge_rows.extend([edge_idx] * len(edge))
    edge_cols.extend(edge)
# Building [rows, cols] directly keeps shape (2, nnz) even when nnz == 0.
# Transposing an empty list would give a 1-D tensor, which
# sparse_coo_tensor rejects.
indices = torch.tensor([edge_rows, edge_cols], dtype=torch.long)
values = torch.ones(len(edge_cols), dtype=torch.float32)
hypergraph_tensor = torch.sparse_coo_tensor(indices, values, (len(Hypergraph_data), max_index))
torch.save(hypergraph_tensor, f"{args.output_dir}/hypergraph_tensor.pt")

# Train/test tensor
# Sparse (num_target_rows, max_index) incidence matrix of the target-table
# rows' feature vertices.
sample_rows, sample_cols = [], []
for sample_idx, vertices in enumerate(train_test_data):
    sample_rows.extend([sample_idx] * len(vertices))
    sample_cols.extend(vertices)
# Shape (2, nnz) even when nnz == 0. The transposed-list form collapses to
# 1-D for empty input and sparse_coo_tensor then raises.
indices = torch.tensor([sample_rows, sample_cols], dtype=torch.long)
values = torch.ones(len(sample_cols), dtype=torch.float32)
train_test_tensor = torch.sparse_coo_tensor(indices, values, (len(train_test_data), max_index))
torch.save(train_test_tensor, f"{args.output_dir}/train_test_tensor.pt")

# Labels tensor
# 1-D float tensor holding one label per target-table row, in row order.
labels_tensor = torch.as_tensor(labels, dtype=torch.float32)
torch.save(labels_tensor, f"{args.output_dir}/labels_tensor.pt")

# Feature tensor (identity)
# Sparse identity matrix: each vertex's feature is its own one-hot vector.
diagonal = torch.arange(max_index)
indices = torch.stack((diagonal, diagonal))
values = torch.ones(max_index, dtype=torch.float32)
feature_tensor = torch.sparse_coo_tensor(indices, values, (max_index, max_index))
torch.save(feature_tensor, f"{args.output_dir}/feature_tensor.pt")

# Positions of the target-table rows within the hypergraph edge list, saved
# as a plain Python list (not a tensor).
idx_path = f"{args.output_dir}/train_test_idx.pt"
torch.save(train_test_idx, idx_path)

print(" All tensors successfully generated and saved from PostgreSQL.")

